{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch Relay 前端测试\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(1, 3, 10, 10), float32] /* span=aten::_convolution_0.x:0:0 */, %aten::_convolution_0.weight: Tensor[(16, 3, 3, 3), float32] /* span=aten::_convolution_0.weight:0:0 */, %aten::_convolution_1.weight: Tensor[(32, 16, 1, 1), float32] /* span=aten::_convolution_1.weight:0:0 */) {\n",
      "  %0 = nn.conv2d(%x, %aten::_convolution_0.weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* span=aten::_convolution_0:0:0 */;\n",
      "  %1 = image.resize2d(%0, size=[4, 4], roi=[0f, 0f, 0f, 0f], method=\"nearest_neighbor\", coordinate_transformation_mode=\"asymmetric\", rounding_method=\"\", cubic_alpha=-0.75f) /* span=aten::upsample_nearest2d_0:0:0 */;\n",
      "  %2 = multiply(4.01903f /* span=aten::mul_0:0:0 */, %1) /* span=aten::mul_0:0:0 */;\n",
      "  nn.conv2d(%2, %aten::_convolution_1.weight, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* span=aten::_convolution_1:0:0 */\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.relay.dataflow_pattern import *\n",
    "from tvm.relay.dataflow_pattern import is_constant as is_const\n",
    "from tvm.relay.testing import run_opt_pass\n",
    "\n",
    "class RewriteMulConv2d(DFPatternCallback):\n",
    "    \"\"\"Fold a constant multiplier feeding a conv2d into the conv2d weight.\n",
    "\n",
    "    Matches ``nn.conv2d(multiply(x, scale), weight)`` where ``scale`` and\n",
    "    ``weight`` are both constants, and rewrites it to\n",
    "    ``nn.conv2d(x, weight * scale)`` carrying over the matched conv2d's attrs.\n",
    "\n",
    "    NOTE(review): the pattern does not constrain the shape of ``scale``. The\n",
    "    rewrite is exact for a scalar multiplier (convolution is linear); confirm\n",
    "    the broadcasting before relying on it for non-scalar constants.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Sub-patterns are kept as attributes so callback() can look up what each matched.\n",
    "        self.x = wildcard()\n",
    "        self.scale = is_const()\n",
    "        self.multiply = is_op(\"multiply\")(self.x, self.scale)\n",
    "        self.weight = is_const()\n",
    "        self.conv = is_op(\"nn.conv2d\")(self.multiply, self.weight)\n",
    "        self.pattern = self.conv\n",
    "\n",
    "    def callback(self, pre, post, matches):\n",
    "        \"\"\"Return the replacement conv2d for one match.\n",
    "\n",
    "        ``matches`` is indexed by sub-pattern; the first matched expression of\n",
    "        each entry is used. ``pre`` and ``post`` are not used.\n",
    "        \"\"\"\n",
    "        data = matches[self.x][0]\n",
    "        # Multiplying the two matched constants yields a new expression that a\n",
    "        # later constant-folding pass can evaluate.\n",
    "        scaled_weight = matches[self.weight][0] * matches[self.scale][0]\n",
    "        conv_attrs = matches[self.conv][0].attrs\n",
    "        # Use the documented builder location, relay.nn.conv2d.\n",
    "        return relay.nn.conv2d(data, scaled_weight, **conv_attrs)\n",
    "\n",
    "@tvm.ir.transform.module_pass(opt_level=3, name=\"MulConv2dRewriter\")\n",
    "class MulConv2dRewriterPipeline:\n",
    "    \"\"\"Module pass: apply RewriteMulConv2d to ``main``, then run FoldConstant.\"\"\"\n",
    "\n",
    "    def transform_module(self, mod, ctx):\n",
    "        rewritten_main = rewrite(RewriteMulConv2d(), mod[\"main\"])\n",
    "        mod[\"main\"] = rewritten_main\n",
    "        fold_constants = relay.transform.FoldConstant()\n",
    "        return fold_constants(mod)\n",
    "\n",
    "class M(torch.nn.Module):\n",
    "    \"\"\"3x3 conv, nearest-neighbour 0.5x resize, constant scaling, then 1x1 conv.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = torch.nn.Conv2d(3, 16, 3, bias=False)\n",
    "        self.conv2 = torch.nn.Conv2d(16, 32, 1, bias=False)\n",
    "        # Stored as a plain tensor attribute (neither a Parameter nor a registered buffer).\n",
    "        self.scale = torch.tensor(4.019027233123779, dtype=torch.float32)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.conv(x)\n",
    "        downsampled = F.interpolate(\n",
    "            features, size=None, scale_factor=(0.5, 0.5), mode=\"nearest\"\n",
    "        )\n",
    "        scaled = self.scale * downsampled\n",
    "        return self.conv2(scaled)\n",
    "\n",
    "\n",
    "# Seed the RNG so the randomly initialised weights and the trace input are reproducible.\n",
    "torch.manual_seed(0)\n",
    "# Put the module in inference mode before tracing it.\n",
    "torch_model = M().eval()\n",
    "shape = (1, 3, 10, 10)\n",
    "input_shapes = [(\"x\", shape)]\n",
    "with torch.no_grad():\n",
    "    trace = torch.jit.trace(torch_model, [torch.randn(*shape)])\n",
    "    mod, params = relay.frontend.from_pytorch(trace, input_shapes)\n",
    "# Show the imported function before any rewriting.\n",
    "print(mod[\"main\"])\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = relay.quantize.prerequisite_optimize(mod, params)\n",
    "    mod = MulConv2dRewriterPipeline()(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(1, 3, 10, 10), float32] /* ty=Tensor[(1, 3, 10, 10), float32] span=aten::_convolution_0.x:0:0 */) -> Tensor[(1, 32, 4, 4), float32] {\n",
      "  %0 = nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 8, 8), float32] span=aten::_convolution_0:0:0 */;\n",
      "  %1 = image.resize2d(%0, size=[4, 4], roi=[0f, 0f, 0f, 0f], method=\"nearest_neighbor\", coordinate_transformation_mode=\"asymmetric\", rounding_method=\"\", cubic_alpha=-0.75f) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten::upsample_nearest2d_0:0:0 */;\n",
      "  nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(32, 16, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 4, 4), float32] */\n",
      "} /* ty=fn (Tensor[(1, 3, 10, 10), float32]) -> Tensor[(1, 32, 4, 4), float32] */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(mod[\"main\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Seeded generator so the comparison input is the same on every run; the shape is\n",
    "# the one the model was traced with.\n",
    "rng = np.random.default_rng(0)\n",
    "inputs_np = [rng.random(shape, dtype=np.float32)]\n",
    "evaluate = relay.create_executor(\"debug\", mod=mod, device=tvm.cpu(), target=\"llvm\").evaluate()\n",
    "tvm_output = evaluate(inputs_np[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    torch_output = torch_model(torch.from_numpy(inputs_np[0]))\n",
    "# float32 has roughly 7 significant digits and folding the scale into the weights\n",
    "# changes the rounding order, so use a relative tolerance float32 can actually meet.\n",
    "np.testing.assert_allclose(tvm_output.numpy(), torch_output.numpy(), rtol=1e-5, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.analysis.analysis import extract_intermdeiate_expr\n",
    "\n",
    "# NOTE(review): the saved output of the following cell shows a much larger graph\n",
    "# (128x128 input) than the three-call module built earlier in this notebook.\n",
    "# Confirm which module `mod` is meant to hold here and that this index is valid for it.\n",
    "INTERMEDIATE_EXPR_ID = 35\n",
    "mod = extract_intermdeiate_expr(mod, INTERMEDIATE_EXPR_ID)\n",
    "mod = relay.transform.InferType()(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%input.1: Tensor[(1, 3, 128, 128), float32] /* ty=Tensor[(1, 3, 128, 128), float32] span=Conv_9.input.1:0:0 */) -> Tensor[(1, 48, 64, 64), float32] {\n",
      "  %0 = nn.conv2d(%input.1, meta[relay.Constant][0] /* ty=Tensor[(32, 3, 3, 3), float32] span=Conv_9.conv_in.weight:0:0 */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_9:0:0 */;\n",
      "  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(32, 1, 1, 3), float32] span=Conv_10.body.0.body.0.dau_top.body.0.weight:0:0 */, padding=[0, 1, 0, 1], groups=32, channels=32, kernel_size=[1, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_10:0:0 */;\n",
      "  %2 = nn.conv2d(%1, meta[relay.Constant][2] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_11.body.0.body.0.dau_top.body.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_11:0:0 */;\n",
      "  %3 = broadcast_to_like(meta[relay.Constant][3] /* ty=Tensor[(32, 1, 1), float32] span=PRelu_12.onnx::PRelu_175:0:0 */, %2) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_12:0:0 */;\n",
      "  %4 = reshape(%2, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_12:0:0 */;\n",
      "  %5 = reshape(%3, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_12:0:0 */;\n",
      "  %6 = nn.prelu(%4, %5, axis=0) /* ty=Tensor[(524288), float32] span=PRelu_12:0:0 */;\n",
      "  %7 = reshape(%6, newshape=[1, 32, 128, 128]) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_12:0:0 */;\n",
      "  %8 = nn.conv2d(%7, meta[relay.Constant][4] /* ty=Tensor[(32, 1, 1, 3), float32] span=Conv_13.body.0.body.0.dau_top.body.3.weight:0:0 */, padding=[0, 1, 0, 1], groups=32, channels=32, kernel_size=[1, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_13:0:0 */;\n",
      "  %9 = nn.conv2d(%8, meta[relay.Constant][5] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_14.body.0.body.0.dau_top.body.4.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_14:0:0 */;\n",
      "  %10 = nn.global_avg_pool2d(%9) /* ty=Tensor[(1, 32, 1, 1), float32] span=GlobalAveragePool_15:0:0 */;\n",
      "  %11 = nn.conv2d(%10, meta[relay.Constant][6] /* ty=Tensor[(2, 32, 1, 1), float32] span=Conv_16.body.0.body.0.dau_top.gcnet.se.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=2, kernel_size=[1, 1]) /* ty=Tensor[(1, 2, 1, 1), float32] span=Conv_16:0:0 */;\n",
      "  %12 = broadcast_to_like(meta[relay.Constant][7] /* ty=Tensor[(2, 1, 1), float32] span=PRelu_17.onnx::PRelu_176:0:0 */, %11) /* ty=Tensor[(1, 2, 1, 1), float32] span=PRelu_17:0:0 */;\n",
      "  %13 = reshape(%11, newshape=[-1]) /* ty=Tensor[(2), float32] span=PRelu_17:0:0 */;\n",
      "  %14 = reshape(%12, newshape=[-1]) /* ty=Tensor[(2), float32] span=PRelu_17:0:0 */;\n",
      "  %15 = nn.prelu(%13, %14, axis=0) /* ty=Tensor[(2), float32] span=PRelu_17:0:0 */;\n",
      "  %16 = reshape(%15, newshape=[1, 2, 1, 1]) /* ty=Tensor[(1, 2, 1, 1), float32] span=PRelu_17:0:0 */;\n",
      "  %17 = nn.conv2d(%16, meta[relay.Constant][8] /* ty=Tensor[(32, 2, 1, 1), float32] span=Conv_18.body.0.body.0.dau_top.gcnet.se.3.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=Conv_18:0:0 */;\n",
      "  %18 = sigmoid(%17) /* ty=Tensor[(1, 32, 1, 1), float32] span=Sigmoid_19:0:0 */;\n",
      "  %19 = nn.conv2d(%18, meta[relay.Constant][9] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_20.body.0.body.0.dau_top.gcnet.channel_add_conv.0.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=Conv_20:0:0 */;\n",
      "  %20 = broadcast_to_like(meta[relay.Constant][10] /* ty=Tensor[(32, 1, 1), float32] span=PRelu_21.onnx::PRelu_177:0:0 */, %19) /* ty=Tensor[(1, 32, 1, 1), float32] span=PRelu_21:0:0 */;\n",
      "  %21 = reshape(%19, newshape=[-1]) /* ty=Tensor[(32), float32] span=PRelu_21:0:0 */;\n",
      "  %22 = reshape(%20, newshape=[-1]) /* ty=Tensor[(32), float32] span=PRelu_21:0:0 */;\n",
      "  %23 = nn.prelu(%21, %22, axis=0) /* ty=Tensor[(32), float32] span=PRelu_21:0:0 */;\n",
      "  %24 = reshape(%23, newshape=[1, 32, 1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=PRelu_21:0:0 */;\n",
      "  %25 = nn.conv2d(%24, meta[relay.Constant][11] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_22.body.0.body.0.dau_top.gcnet.channel_add_conv.2.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=Conv_22:0:0 */;\n",
      "  %26 = add(%9, %25) /* ty=Tensor[(1, 32, 128, 128), float32] span=Add_23:0:0 */;\n",
      "  %27 = broadcast_to_like(meta[relay.Constant][3] /* ty=Tensor[(32, 1, 1), float32] span=PRelu_12.onnx::PRelu_175:0:0 */, %26) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_24:0:0 */;\n",
      "  %28 = reshape(%26, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_24:0:0 */;\n",
      "  %29 = reshape(%27, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_24:0:0 */;\n",
      "  %30 = nn.prelu(%28, %29, axis=0) /* ty=Tensor[(524288), float32] span=PRelu_24:0:0 */;\n",
      "  %31 = reshape(%30, newshape=[1, 32, 128, 128]) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_24:0:0 */;\n",
      "  %32 = add(%31, %0) /* ty=Tensor[(1, 32, 128, 128), float32] span=Add_25:0:0 */;\n",
      "  %33 = image.resize2d(%32, size=[64, 64], roi=[0f, 0f, 0f, 0f], method=\"nearest_neighbor\", coordinate_transformation_mode=\"asymmetric\", rounding_method=\"floor\", cubic_alpha=-0.75f) /* ty=Tensor[(1, 32, 64, 64), float32] span=Resize_27:0:0 */;\n",
      "  %34 = multiply(4.01903f /* ty=float32 span=Mul_28.body.0.body.0.down2.0.alpha:0:0 */, %33) /* ty=Tensor[(1, 32, 64, 64), float32] span=Mul_28:0:0 */;\n",
      "  nn.conv2d(%34, meta[relay.Constant][12] /* ty=Tensor[(48, 32, 1, 1), float32] span=Conv_29.body.0.body.0.down2.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=48, kernel_size=[1, 1]) /* ty=Tensor[(1, 48, 64, 64), float32] span=Conv_29:0:0 */\n",
      "} /* ty=fn (Tensor[(1, 3, 128, 128), float32]) -> Tensor[(1, 48, 64, 64), float32] */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(mod[\"main\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xxx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
